{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Information Retrieval metrics\n",
    "\n",
    "Useful Resources:\n",
    "\n",
    "http://www.cs.utexas.edu/~mooney/ir-course/slides/Evaluation.ppt\n",
    "\n",
    "http://www.nii.ac.jp/TechReports/05-014E.pdf\n",
    "\n",
    "http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf\n",
    "\n",
    "http://hal.archives-ouvertes.fr/docs/00/72/67/60/PDF/07-busa-fekete.pdf\n",
    "\n",
    "Learning to Rank for Information Retrieval (Tie-Yan Liu)\n",
    "\n",
    "\n",
    "https://gist.github.com/bwhite/3726239\n",
    "\n",
    "### Mean reciprocal rank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def mean_reciprocal_rank(rs):\n",
    "    \"\"\"Compute the mean reciprocal rank (MRR) of several ranked lists.\n",
    "\n",
    "    Each ranking contributes 1 / (rank of its first relevant item), where\n",
    "    the first element is rank 1. A ranking without any relevant item\n",
    "    contributes 0. Relevance is binary (nonzero is relevant).\n",
    "\n",
    "    Example from http://en.wikipedia.org/wiki/Mean_reciprocal_rank\n",
    "    >>> rs = [[0, 0, 1], [0, 1, 0], [1, 0, 0]]\n",
    "    >>> round(float(mean_reciprocal_rank(rs)), 4)\n",
    "    0.6111\n",
    "    >>> rs = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]])\n",
    "    >>> float(mean_reciprocal_rank(rs))\n",
    "    0.5\n",
    "    >>> rs = [[0, 0, 0, 1], [1, 0, 0], [1, 0, 0]]\n",
    "    >>> float(mean_reciprocal_rank(rs))\n",
    "    0.75\n",
    "\n",
    "    Args:\n",
    "        rs: Iterator of relevance scores (list or numpy) in rank order\n",
    "            (first element is the first item)\n",
    "    Returns:\n",
    "        Mean reciprocal rank\n",
    "    \"\"\"\n",
    "    reciprocal_ranks = []\n",
    "    for ranking in rs:\n",
    "        # Positions (0-based) of every relevant item in this ranking.\n",
    "        hit_positions = np.asarray(ranking).nonzero()[0]\n",
    "        if hit_positions.size:\n",
    "            # +1 converts the 0-based index of the first hit to a rank.\n",
    "            reciprocal_ranks.append(1. / (hit_positions[0] + 1))\n",
    "        else:\n",
    "            reciprocal_ranks.append(0.)\n",
    "    return np.mean(reciprocal_ranks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.611111111111111"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rs = [[0, 0, 1], [0, 1, 0], [1, 0, 0]]\n",
    "mean_reciprocal_rank(rs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rs = [[0 , 0, 0], [0, 1, 0], [1, 0, 0]]\n",
    "mean_reciprocal_rank(rs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.75"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rs = [[0, 0, 0, 1], [1, 0, 0], [1, 0, 0]]\n",
    "mean_reciprocal_rank(rs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### R Precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def r_precision(r):\n",
    "    \"\"\"Score is precision after all relevant documents have been retrieved\n",
    "\n",
    "    The ranking is cut off at its last relevant item and the fraction of\n",
    "    relevant items in that prefix is returned.\n",
    "    Relevance is binary (nonzero is relevant).\n",
    "\n",
    "    >>> round(float(r_precision([0, 0, 1])), 4)\n",
    "    0.3333\n",
    "    >>> float(r_precision([0, 1, 0]))\n",
    "    0.5\n",
    "    >>> float(r_precision([1, 0, 0]))\n",
    "    1.0\n",
    "\n",
    "    Args:\n",
    "        r: Relevance scores (list or numpy) in rank order\n",
    "            (first element is the first item)\n",
    "    Returns:\n",
    "        R Precision (0. when no item is relevant)\n",
    "    \"\"\"\n",
    "    # Binarize: any nonzero score counts as relevant.\n",
    "    r = np.asarray(r) != 0\n",
    "    z = r.nonzero()[0]\n",
    "    if not z.size:\n",
    "        # No relevant item at all, so there is no cut-off position.\n",
    "        return 0.\n",
    "    # z[-1] is the index of the last relevant item; average over that prefix.\n",
    "    return np.mean(r[:z[-1] + 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3333333333333333"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r = [0, 0, 1]\n",
    "r_precision(r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r = [0, 1, 0]\n",
    "r_precision(r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r = [1, 0, 0]\n",
    "r_precision(r)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Precision @ k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def precision_at_k(r, k):\n",
    "    \"\"\"Score is precision @ k\n",
    "\n",
    "    Fraction of the first k items that are relevant.\n",
    "    Relevance is binary (nonzero is relevant).\n",
    "\n",
    "    >>> r = [0, 0, 1]\n",
    "    >>> float(precision_at_k(r, 1))\n",
    "    0.0\n",
    "    >>> float(precision_at_k(r, 2))\n",
    "    0.0\n",
    "    >>> round(float(precision_at_k(r, 3)), 4)\n",
    "    0.3333\n",
    "    >>> precision_at_k(r, 4)\n",
    "    Traceback (most recent call last):\n",
    "        ...\n",
    "    ValueError: Relevance score length < k\n",
    "\n",
    "    Args:\n",
    "        r: Relevance scores (list or numpy) in rank order\n",
    "            (first element is the first item)\n",
    "        k: Number of top-ranked items to consider (must be >= 1)\n",
    "    Returns:\n",
    "        Precision @ k\n",
    "    Raises:\n",
    "        AssertionError: k < 1\n",
    "        ValueError: len(r) must be >= k\n",
    "    \"\"\"\n",
    "    assert k >= 1\n",
    "    # Keep the top-k items and binarize them (nonzero means relevant).\n",
    "    r = np.asarray(r)[:k] != 0\n",
    "    if r.size != k:\n",
    "        # Slicing silently truncates, so a short ranking is detected here.\n",
    "        raise ValueError('Relevance score length < k')\n",
    "    return np.mean(r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r = [0, 0, 1]\n",
    "precision_at_k(r, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision_at_k(r, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3333333333333333"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision_at_k(r, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r = [1, 0, 0]\n",
    "precision_at_k(r, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision_at_k(r, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3333333333333333"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision_at_k(r, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r = [0, 1, 0]\n",
    "precision_at_k(r, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision_at_k(r, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3333333333333333"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision_at_k(r, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Relevance sort length < k",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-40-aa7bfb3f9958>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mprecision_at_k\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mr\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m4\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m<ipython-input-25-1c6987048039>\u001b[0m in \u001b[0;36mprecision_at_k\u001b[1;34m(r, k)\u001b[0m\n\u001b[0;32m      3\u001b[0m     \u001b[0mr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mk\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[0mk\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m         \u001b[1;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'Relevance sort length < k'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      6\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mValueError\u001b[0m: Relevance sort length < k"
     ]
    }
   ],
   "source": [
    "precision_at_k(r, 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Average Precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def average_precision(r):\n",
    "    \"\"\"Score is average precision (area under PR curve)\n",
    "\n",
    "    Mean of the precision values measured at the rank of every relevant\n",
    "    item. Relevance is binary (nonzero is relevant).\n",
    "\n",
    "    >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1]\n",
    "    >>> round(float(average_precision(r)), 4)\n",
    "    0.7833\n",
    "    >>> float(average_precision([0, 0]))\n",
    "    0.0\n",
    "\n",
    "    Args:\n",
    "        r: Relevance scores (list or numpy) in rank order\n",
    "            (first element is the first item)\n",
    "    Returns:\n",
    "        Average precision (0. when no item is relevant)\n",
    "    \"\"\"\n",
    "    r = np.asarray(r) != 0\n",
    "    # 0-based positions of the relevant items.\n",
    "    hit_positions = r.nonzero()[0]\n",
    "    if not hit_positions.size:\n",
    "        return 0.\n",
    "    # At the i-th relevant item (1-based) exactly i relevant items have been\n",
    "    # seen, so precision there is i / rank. This computes every needed\n",
    "    # precision in one pass instead of re-averaging a prefix per hit.\n",
    "    hits_so_far = np.arange(1, hit_positions.size + 1)\n",
    "    precisions = hits_so_far / (hit_positions + 1.)\n",
    "    return np.mean(precisions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7833333333333333"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1]\n",
    "average_precision(r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7833333333333333"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "delta_r = 1./ sum(r) \n",
    "sum([sum(r[: x+1]) / (x+1.) * delta_r for x, y in enumerate(r) if y])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.2"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "delta_r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.2, 0.2, 0.15000000000000002, 0.13333333333333333, 0.1]"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[sum(r[: x+1]) / (x+1.) * delta_r for x, y in enumerate(r) if y]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mean Average Precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean_average_precision(rs):\n",
    "    \"\"\"Mean of the average precision over several ranked lists (MAP).\n",
    "\n",
    "    Relevance is binary (nonzero is relevant).\n",
    "\n",
    "    >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1]]\n",
    "    >>> round(float(mean_average_precision(rs)), 4)\n",
    "    0.7833\n",
    "    >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1], [0]]\n",
    "    >>> round(float(mean_average_precision(rs)), 4)\n",
    "    0.3917\n",
    "\n",
    "    Args:\n",
    "        rs: Iterator of relevance scores (list or numpy) in rank order\n",
    "            (first element is the first item)\n",
    "    Returns:\n",
    "        Mean average precision\n",
    "    \"\"\"\n",
    "    # One average-precision score per ranking, then a plain mean over them.\n",
    "    per_ranking_scores = []\n",
    "    for ranking in rs:\n",
    "        per_ranking_scores.append(average_precision(ranking))\n",
    "    return np.mean(per_ranking_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7833333333333333"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1]]\n",
    "mean_average_precision(rs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.39166666666666666"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1], [0]]\n",
    "mean_average_precision(rs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Discounted Cumulative Gain (DCG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dcg_at_k(r, k, method=0):\n",
    "    \"\"\"Score is discounted cumulative gain (dcg)\n",
    "\n",
    "    Relevance is positive real values. Can use binary\n",
    "    as the previous methods.\n",
    "\n",
    "    Example from\n",
    "    http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf\n",
    "    >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0]\n",
    "    >>> float(dcg_at_k(r, 1))\n",
    "    3.0\n",
    "    >>> float(dcg_at_k(r, 1, method=1))\n",
    "    3.0\n",
    "    >>> float(dcg_at_k(r, 2))\n",
    "    5.0\n",
    "    >>> round(float(dcg_at_k(r, 2, method=1)), 4)\n",
    "    4.2619\n",
    "    >>> round(float(dcg_at_k(r, 10)), 4)\n",
    "    9.6051\n",
    "    >>> round(float(dcg_at_k(r, 11)), 4)\n",
    "    9.6051\n",
    "\n",
    "    Args:\n",
    "        r: Relevance scores (list or numpy) in rank order\n",
    "            (first element is the first item)\n",
    "        k: Number of results to consider\n",
    "        method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...]\n",
    "                If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...]\n",
    "    Returns:\n",
    "        Discounted cumulative gain (0. when there are no scores to sum)\n",
    "    Raises:\n",
    "        ValueError: method is neither 0 nor 1 (only checked when at least\n",
    "            one score is available)\n",
    "    \"\"\"\n",
    "    # np.asfarray was removed in NumPy 2.0; asarray with a float dtype is\n",
    "    # the documented replacement and also works on older releases.\n",
    "    r = np.asarray(r, dtype=float)[:k]\n",
    "    if r.size:\n",
    "        if method == 0:\n",
    "            # Rank 1 is undiscounted; rank i >= 2 is divided by log2(i).\n",
    "            return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))\n",
    "        elif method == 1:\n",
    "            # Rank i is divided by log2(i + 1), so rank 1 keeps weight 1.\n",
    "            return np.sum(r / np.log2(np.arange(2, r.size + 2)))\n",
    "        else:\n",
    "            raise ValueError('method must be 0 or 1.')\n",
    "    return 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.0"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0]\n",
    "dcg_at_k(r, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.0"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dcg_at_k(r, 1, method=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5.0"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dcg_at_k(r, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4.2618595071429155"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dcg_at_k(r, 2, method=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9.605117739188811"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dcg_at_k(r, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9.605117739188811"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dcg_at_k(r, 11)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  Normalized Discounted Cumulative Gain (NDCG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ndcg_at_k(r, k, method=0):\n",
    "    \"\"\"Normalized discounted cumulative gain (ndcg) of a ranking at depth k.\n",
    "\n",
    "    The DCG of the given order is divided by the DCG of the same scores\n",
    "    sorted in descending order. Relevance is positive real values; binary\n",
    "    relevance works as well.\n",
    "\n",
    "    Example from\n",
    "    http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf\n",
    "    >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0]\n",
    "    >>> float(ndcg_at_k(r, 1))\n",
    "    1.0\n",
    "    >>> r = [2, 1, 2, 0]\n",
    "    >>> round(float(ndcg_at_k(r, 4)), 4)\n",
    "    0.9203\n",
    "    >>> round(float(ndcg_at_k(r, 4, method=1)), 4)\n",
    "    0.9652\n",
    "    >>> float(ndcg_at_k([0], 1))\n",
    "    0.0\n",
    "    >>> float(ndcg_at_k([1], 2))\n",
    "    1.0\n",
    "\n",
    "    Args:\n",
    "        r: Relevance scores (list or numpy) in rank order\n",
    "            (first element is the first item)\n",
    "        k: Number of results to consider\n",
    "        method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...]\n",
    "                If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...]\n",
    "    Returns:\n",
    "        Normalized discounted cumulative gain (0. when the ideal DCG is 0)\n",
    "    \"\"\"\n",
    "    # Best achievable ordering: highest relevance first.\n",
    "    ideal_order = sorted(r, reverse=True)\n",
    "    best_dcg = dcg_at_k(ideal_order, k, method)\n",
    "    if best_dcg:\n",
    "        return dcg_at_k(r, k, method) / best_dcg\n",
    "    # A zero ideal DCG would mean dividing by zero, so the score is 0.\n",
    "    return 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0]\n",
    "ndcg_at_k(r, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9203032077642922"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r = [2, 1, 2, 0]\n",
    "ndcg_at_k(r, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9651954696014428"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ndcg_at_k(r, 4, method = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ndcg_at_k([0], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ndcg_at_k([1], 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
